{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Note: we are in the process of updating the links in this notebook. If a link doesn't work, please open an issue and we'll rectify it ASAP. Thanks for your understanding!\n",
        "\n",
        "Links to add:\n",
        "* Cell 1: long-form, grounded summarisation blog post\n",
        "* Section 4: to text-rank method (context filtering)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "TAjb3BOMsYjE"
      },
      "source": [
        "# Long-form summarization with citations using grounded generation\n",
        "\n",
        "This notebook provides the code to produce the outputs described in [this blog post](https://docs.google.com/document/d/1Eeakpz_FZoeMzJnQieqQWCpPtQuNiTGW4fueU9J0QHA/edit).\n",
        "\n",
        "## Table of contents\n",
        "\n",
        "1. [Setup](#setup)\n",
        "2. [Out-of-the-box summarization with Command-R](#out-of-the-box-summarization-with-command-r)\n",
        "3. [Introduce citations to the summary for grounding](#introduce-citations-to-the-summary-for-grounding)\n",
        "4. [Reduce the cost of summarization calls](#reduce-the-cost-of-summarization-calls)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "zF9yEHUFt9up"
      },
      "source": [
        "<a id=\"setup\"></a>\n",
        "<a name=\"setup\"></a>\n",
        "## 1. Setup"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "x8CjM6c8EnVK",
        "outputId": "51acb907-2567-49f7-95df-85114dc975e3"
      },
      "outputs": [],
      "source": [
        "%%capture\n",
        "# TODO: upgrade to \"cohere>5\"\n",
        "!pip install \"cohere<5\" networkx\n",
        "\n",
        "import cohere\n",
        "import networkx as nx\n",
        "import nltk\n",
        "nltk.download(\"punkt\")\n",
        "from nltk.tokenize import sent_tokenize\n",
        "import numpy as np\n",
        "\n",
        "from collections import deque\n",
        "from getpass import getpass\n",
        "import re\n",
        "from typing import List, Tuple\n",
        "\n",
        "# Set up Cohere client\n",
        "co_api_key = getpass(\"Enter your Cohere API key: \")\n",
        "co_model = \"command-r\"\n",
        "co = cohere.Client(co_api_key)\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "z2DBeeGTEnya",
        "outputId": "93a5f50a-f1db-445b-c0e3-2a619975da61"
      },
      "outputs": [],
      "source": [
        "# Read IMF report\n",
        "\n",
        "from google.colab import drive\n",
        "drive.mount(\"/content/drive\", force_remount=True)\n",
        "\n",
        "fpath = \"drive/Shareddrives/FDE/Cookbooks/Long-form summarisation/ai_and_future_of_work.txt\"\n",
        "with open(fpath, \"r\") as f:\n",
        "  text = f.read()\n",
        "\n",
        "num_tokens = co.tokenize(text).length\n",
        "print(f\"Loaded IMF report with {num_tokens} tokens\")\n",
        "\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Xe2TzS6q28D6"
      },
      "source": [
        "### Aside: define utils"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "eFkiLPTIzZw4"
      },
      "outputs": [],
      "source": [
        "# Utils!\n",
        "\n",
        "# --- for chunking ---\n",
        "def split_text_into_sentences(text: str) -> List[str]:\n",
        "    sentences =  sent_tokenize(text)\n",
        "    return sentences\n",
        "\n",
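        "def group_sentences_into_passages(sentence_list: List[str], n_sentences_per_passage: int = 10) -> List[str]:\n",
        "    \"\"\"\n",
        "    Group sentences into passages of `n_sentences_per_passage` sentences each.\n",
        "    The final passage may be shorter if the sentences do not divide evenly.\n",
        "    \"\"\"\n",
        "    passages = []\n",
        "    passage = \"\"\n",
        "    for i, sentence in enumerate(sentence_list):\n",
        "        passage += sentence + \" \"\n",
        "        if (i + 1) % n_sentences_per_passage == 0:\n",
        "            passages.append(passage)\n",
        "            passage = \"\"\n",
        "    # Keep the trailing sentences that did not fill a complete passage\n",
        "    if passage:\n",
        "        passages.append(passage)\n",
        "    return passages\n",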
        "\n",
        "def build_simple_chunks(text, n_sentences: int = 10):\n",
        "    \"\"\"\n",
        "    Build chunks of text from the input text.\n",
        "    \"\"\"\n",
        "    sentences = split_text_into_sentences(text)\n",
        "    chunks = group_sentences_into_passages(sentences, n_sentences_per_passage=n_sentences)\n",
        "    return chunks\n",
        "\n",
        "\n",
        "# --- for visualising citations ---\n",
        "\n",
        "def insert_citations(text: str, citations: List[dict]):\n",
        "    \"\"\"\n",
        "    A helper function to pretty print citations.\n",
        "    \"\"\"\n",
        "    offset = 0\n",
        "    # Process citations in order of position, so the running offset stays valid\n",
        "    for citation in sorted(citations, key=lambda c: c['start']):\n",
        "        # Adjust start/end with offset\n",
        "        start, end = citation['start'] + offset, citation['end'] + offset\n",
        "        placeholder = \"[\" + \", \".join(doc[4:] for doc in citation[\"document_ids\"]) + \"]\"\n",
        "        # ^ doc[4:] removes the 'doc_' prefix, and leaves the document index\n",
        "        modification = f'{text[start:end]} {placeholder}'\n",
        "        # Replace the cited text with the same text followed by its placeholder\n",
        "        text = text[:start] + modification + text[end:]\n",
        "        # Update the offset for subsequent replacements\n",
        "        offset += len(modification) - (end - start)\n",
        "\n",
        "    return text\n",
        "\n",
        "\n",
        "# --- for reducing number of tokens sent to model intelligently ---\n",
        "\n",
        "def textrank(text: str, co, max_tokens: int, n_sentences_per_passage: int) -> str:\n",
        "    \"\"\"\n",
        "    Shortens `text` by extracting key units of text from `text` based on their centrality and concatenating them.\n",
        "    The output is the concatenation of those key units, in their original order. Centrality is a graph-theoretic\n",
        "    measure of the connectedness of a node; the more connected a node is to surrounding nodes (and the more sparsely\n",
        "    those neighbours are connected), the higher its centrality.\n",
        "\n",
        "    Key chunks are identified and assembled in five steps:\n",
        "    1. Break up `text` into chunks of `n_sentences_per_passage` sentences\n",
        "    2. Embed each chunk using Cohere's embedding model and construct a similarity matrix\n",
        "    3. Compute the centrality of each chunk\n",
        "    4. Keep the highest-centrality chunks until `max_tokens` is reached\n",
        "    5. Put together the shortened text by restoring the selected chunks to their original order\n",
        "\n",
        "    This approach is based on summarise.long_doc_summarization.extraction::extract_single_doc with sorting by\n",
        "    centrality. Adapted here because installing the `summarise` repo would have added a lot of unused functionalities\n",
        "    and dependencies.\n",
        "    \"\"\"\n",
        "\n",
        "    # 1. Chunk text into units\n",
        "    chunks = build_simple_chunks(text, n_sentences_per_passage)\n",
        "\n",
        "    # 2. Embed and construct similarity matrix\n",
        "    embeddings = np.array(\n",
        "        co.embed(\n",
        "            texts=chunks,\n",
        "            model=\"embed-english-v3.0\",\n",
        "            input_type=\"clustering\",\n",
        "        ).embeddings\n",
        "    )\n",
        "    similarities = np.dot(embeddings, embeddings.T)\n",
        "\n",
        "    # 3. Compute centrality and sort chunks by centrality\n",
        "    # Easiest to use networkx's `degree` function with similarity as weight\n",
        "    g = nx.from_numpy_array(similarities, edge_attr=\"weight\")\n",
        "    centralities = g.degree(weight=\"weight\")\n",
        "    idcs_sorted_by_centrality = [node for node, degree in sorted(centralities, key=lambda item: item[1], reverse=True)]\n",
        "\n",
        "    # 4. Add chunks back in order of centrality\n",
        "    selected = _add_chunks_by_priority(co, chunks, idcs_sorted_by_centrality, max_tokens)\n",
        "\n",
        "    # 5. Put condensed text back in original order\n",
        "    separator = \"\\n\"\n",
        "    short = separator.join([chunk for index, chunk in sorted(selected, key=lambda item: item[0], reverse=False)])\n",
        "\n",
        "    return short\n",
        "\n",
        "\n",
        "def _add_chunks_by_priority(\n",
        "    co, chunks: List[str], idcs_sorted_by_priority: List[int], max_tokens: int\n",
        ") -> List[Tuple[int, str]]:\n",
        "    \"\"\"\n",
        "    Given chunks of text and their indices sorted by priority (highest priority first), this function\n",
        "    fills the model context window with as many highest-priority chunks as possible.\n",
        "\n",
        "    The output is a list of (index, chunk) pairs, ordered by priority. To stitch back the chunks into\n",
        "    a cohesive text that preserves chronological order, sort the output on its index.\n",
        "    \"\"\"\n",
        "\n",
        "    selected = []\n",
        "    num_tokens = 0\n",
        "    idcs_queue = deque(idcs_sorted_by_priority)\n",
        "\n",
        "    while num_tokens < max_tokens and len(idcs_queue) > 0:\n",
        "        next_idx = idcs_queue.popleft()\n",
        "        num_tokens += co.tokenize(chunks[next_idx]).length - 2\n",
        "        # num_tokens += len(tokenizer.encode(chunks[next_idx]).ids) - 2\n",
        "        # ^ removing BOS and EOS tokens from count\n",
        "        selected.append((next_idx, chunks[next_idx]))\n",
        "        # ^ keep index and chunk, to reorder chronologically\n",
        "    if num_tokens > max_tokens:\n",
        "        selected.pop()\n",
        "\n",
        "    return selected\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "PMt7yd7z3Gth"
      },
      "source": [
        "<a id=\"out-of-the-box-summarization-with-command-r\"></a>\n",
        "<a name=\"out-of-the-box-summarization-with-command-r\"></a>\n",
        "## 2. Out-of-the-box summarization with Command-R\n",
        "\n",
        "First, let's see Command-R's out-of-the-box performance. It's a 128k-context model, so we can pass the full IMF report in a single call. We replicate the exact instructions from the original tweet (correcting a minor typo) to enable a fair comparison."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "nXsYVkFaEn5l",
        "outputId": "ae9f5b33-656a-4d55-d382-e834a88d79a5"
      },
      "outputs": [],
      "source": [
        "prompt_template = \"\"\"\\\n",
        "## text\n",
        "{text}\n",
        "\n",
        "## instructions\n",
        "Step 1. Read the entire text from the first to the last page.\n",
        "Step 2. Create a summary of every chapter from the first to the last page.\n",
        "\n",
        "## summary\n",
        "\"\"\"\n",
        "\n",
        "prompt = prompt_template.format(text=text)\n",
        "resp = co.chat(\n",
        "  message=prompt,\n",
        "  model=co_model,\n",
        "  temperature=0.3,\n",
        "  return_prompt=True\n",
        ")\n",
        "\n",
        "num_tokens_in = co.tokenize(resp.prompt).length\n",
        "num_tokens_out = resp.meta[\"billed_units\"][\"output_tokens\"]\n",
        "print(f\"Generated summary with {num_tokens_in} tokens in, {num_tokens_out} tokens out\")\n",
        "print()\n",
        "print(\"--- Out-of-the-box summary with Command-R ---\")\n",
        "print()\n",
        "print(resp.text)\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "NLZsodOG41bb"
      },
      "source": [
        "<a id=\"introduce-citations-to-the-summary-for-grounding\"></a>\n",
        "<a name=\"introduce-citations-to-the-summary-for-grounding\"></a>\n",
        "## 3. Introduce citations to the summary for grounding\n",
        "\n",
        "When summarizing long documents, introducing citations is one simple method for checking the factuality of the summary without needing to read the full document.\n",
        "\n",
        "\n",
        "We've trained Command-R to introduce citations whenever prompted by our grounded generation instructions. Triggering this grounded mode is straightforward. Starting from the previous snippet, we only need to make two changes:\n",
        "1. Pass our text to the `documents` argument\n",
        "2. Pass our instructions to the `message` argument\n",
        "\n",
        "For more information on how to enable grounded generation via our `co.chat` API, please refer to our [documentation](https://docs.cohere.com/reference/chat).\n",
        "\n",
        "Finally, note that we chunk the IMF report into multiple documents before passing them to `co.chat`. This isn't necessary (`co.chat` annotates citations at the character level), but allows for more human-readable citations."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "0-nrG21VEn9a",
        "outputId": "0edfc184-32cc-4f84-a5a5-aafce84a6339"
      },
      "outputs": [],
      "source": [
        "summarize_preamble = \"\"\"\\\n",
        "You will receive a series of text fragments from an article that are presented in chronological order. \\\n",
        "As the assistant, you must generate responses to user's requests based on the information given in the fragments. \\\n",
        "Ensure that your responses are accurate and truthful, and that you reference your sources where appropriate to answer \\\n",
        "the queries, regardless of their complexity.\\\n",
        "\"\"\"\n",
        "\n",
        "instructions = \"\"\"\\\n",
        "## instructions\n",
        "Step 1. Read the entire text from the first to the last page.\n",
        "Step 2. Create a summary of every chapter from the first to the last page.\n",
        "\"\"\"\n",
        "\n",
        "# Chunk long text into multiple chunks for readable citations\n",
        "chunked = build_simple_chunks(text, n_sentences=30)\n",
        "# Use `message` and `documents` arguments to trigger grounded generation\n",
        "resp = co.chat(\n",
        "  preamble=summarize_preamble,\n",
        "  message=instructions,\n",
        "  documents=[{\"text\": chunk} for chunk in chunked],\n",
        "  model=co_model,\n",
        "  temperature=0.3,\n",
        "  return_prompt=True\n",
        ")\n",
        "# Note: the grounded generation pipeline takes longer when documents are chunked\n",
        "# more finely. For latency-sensitive applications, try tuning the size of chunks!\n",
        "\n",
        "num_tokens_in = co.tokenize(resp.prompt).length\n",
        "num_tokens_out = resp.meta[\"billed_units\"][\"output_tokens\"]\n",
        "print(f\"Generated summary with {num_tokens_in} tokens in, {num_tokens_out} tokens out\")\n",
        "print()\n",
        "print(\"--- Summary with citations using grounded generation in Command-R ---\")\n",
        "print()\n",
        "print(resp.text)\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "7oWRUqAgUKfX"
      },
      "source": [
        "Let's display the citations inside our answer:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "vq-F26hCUJjl",
        "outputId": "f41f5d9e-c3d1-425b-e940-feea85d3797d"
      },
      "outputs": [],
      "source": [
        "print(insert_citations(resp.text, resp.citations))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "7zC9Zt8ZXBZM"
      },
      "source": [
        "We can now visualise which section of the answer is based on which passage in the main text. Verifying factuality is straightforward: pick a section and verify that the relevant information is contained in the cited chunk.\n",
        "\n",
        "For instance, let's verify the statement\n",
        "```\n",
        "Around 40% of employment worldwide is exposed to AI [1, 6]\n",
        "```\n",
        "by checking its chunk:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "ADq2XJhlUnEI",
        "outputId": "7aea0f41-cd8a-4132-f7a7-48482549c962"
      },
      "outputs": [],
      "source": [
        "print(chunked[6])"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "kQgmDucZYNUi"
      },
      "source": [
        "Seems convincing!\n",
        "By repeating such checks, it's straightforward to build trust in your summaries."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "OnWyAyRI7U01"
      },
      "source": [
        "<a id=\"reduce-the-cost-of-summarization-calls\"></a>\n",
        "<a name=\"reduce-the-cost-of-summarization-calls\"></a>\n",
        "## 4. Reduce the cost of summarization calls\n",
        "\n",
        "Even though Command-R is an efficient, lightweight model, for some applications we may accept trading off some summarization quality for lower costs. To do this, we must reduce the number of tokens sent to the model -- but how do we select the most relevant bits?\n",
        "\n",
        "We have a whole notebook dedicated to methods for reducing context length. Here, we call our 'text-rank' method to select maximally central chunks in a graph based on the chunk-to-chunk similarities. For more detail, please refer to [this cookbook](https://colab.research.google.com/drive/1zxSAbruOWwWJHNsj3N56uxZtUeiS7Evd)."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "4BZZ_cc1EoEA",
        "outputId": "1a077f8e-0363-48fe-c99a-6b4783b2982d"
      },
      "outputs": [],
      "source": [
        "# First, we filter the original text down to a smaller number of tokens.\n",
        "# This will reduce cost and improve latency.\n",
        "num_tokens = 8192\n",
        "shortened = textrank(text, co, num_tokens, n_sentences_per_passage=30)\n",
        "\n",
        "# Then, we apply grounded generation to keep citations on our (now shorter) report\n",
        "chunked = build_simple_chunks(shortened)\n",
        "resp = co.chat(\n",
        "  message=instructions,\n",
        "  documents=[{\"text\": chunk} for chunk in chunked],\n",
        "  model=co_model,\n",
        "  temperature=0.3,\n",
        "  return_prompt=True\n",
        ")\n",
        "\n",
        "num_tokens_in = co.tokenize(resp.prompt).length\n",
        "num_tokens_out = resp.meta[\"billed_units\"][\"output_tokens\"]\n",
        "print(f\"Generated summary with {num_tokens_in} tokens in, {num_tokens_out} tokens out\")\n",
        "print()\n",
        "print(\"--- Summary with citations using text-rank + grounding in Command-R ---\")\n",
        "print()\n",
        "print(resp.text)\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "OBcnnR6MUG5z"
      },
      "source": [
        "The summary is looking convincing! In practice, the trade-off between cost-efficiency and performance should be considered carefully."
      ]
    }
  ],
  "metadata": {
    "colab": {
      "provenance": []
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "name": "python",
      "version": "3.10.13"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
